"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""

import requests
import hashlib
import os
import torch
import torch.nn as nn
from torchvision import models
from collections import namedtuple
from tqdm import tqdm
from image_synthesis.modeling.modules.vqgan_loss.moco_net import resnet50

# from taming.util import get_ckpt_path

class MOCOPER(nn.Module):
    """Perceptual distance built on frozen MoCo ResNet-50 features.

    Both images are normalized by ``ScalingLayer`` and passed through the
    backbone returned by ``resnet50()`` with ``index_list=[0, 1, 2, 3, 4]``.
    For every returned feature level, the squared feature difference is
    averaged by ``spatial_average``. The per-level averages are then summed.

    Adapted from the LPIPS code this module was stripped from. Unlike LPIPS
    there are no learned linear heads, so every feature level is weighted
    equally.
    """

    def __init__(self, use_dropout=True):
        # ``use_dropout`` is kept for signature compatibility with LPIPS; it is
        # not used by this module.
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.moco_r50_model = resnet50()
        # NOTE(review): no weights are loaded here. A disabled loader used to
        # live in this spot: it read "OUTPUT/r-50-1000ep.pth.tar" and stripped
        # the "module.base_encoder." prefix from its keys. Presumably the
        # pretrained weights are loaded elsewhere — confirm.

        # The classification head is not needed; forward() only consumes the
        # features selected via ``index_list``.
        self.moco_r50_model.fc = torch.nn.Identity()

        # The backbone is a fixed metric: never update its weights.
        for param in self.moco_r50_model.parameters():
            param.requires_grad = False

        self.trainable = False
        # nn.Module.__init__ leaves ``training=True``. Apply the eval-mode
        # policy of ``train()`` immediately so a freshly constructed metric
        # does not run its frozen backbone in training mode.
        self.train(False)

    def train(self, mode=True):
        """Set the training flag, forcing eval mode while ``trainable`` is False.

        Parent modules calling ``.train()`` recurse into this override, so the
        frozen backbone stays in eval mode unless ``self.trainable`` is set.

        Returns:
            ``self``, as returned by ``nn.Module.train``.
        """
        return super().train(bool(self.trainable and mode))

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load the checkpoint resolved by ``get_ckpt_path(name)`` non-strictly.

        NOTE(review): ``get_ckpt_path`` is not imported in this module (its
        ``taming.util`` import is commented out), so calling this raises
        NameError unless that name is provided some other way. The default
        name refers to the VGG LPIPS checkpoint inherited from the upstream
        code. With ``strict=False``, missing/unexpected keys are ignored —
        confirm this method is still meant to be used with this model.
        """
        ckpt = get_ckpt_path(name)
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        """Build an instance and load the ``vgg_lpips`` checkpoint into it.

        Raises:
            NotImplementedError: if ``name`` is not ``"vgg_lpips"``.

        NOTE(review): relies on ``get_ckpt_path``, which is not imported in
        this module (see ``load_from_pretrained``).
        """
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        ckpt = get_ckpt_path(name)
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def _features(self, x):
        """Normalize ``x`` and return the backbone features at indices 0-4."""
        return self.moco_r50_model(self.scaling_layer(x), index_list=[0, 1, 2, 3, 4])

    def forward(self, input, target):
        """Return the perceptual distance between ``input`` and ``target``.

        Args:
            input: image batch; assumed (B, 3, H, W) in the value range that
                ``ScalingLayer``'s shift/scale expect — TODO confirm.
            target: image batch shaped like ``input``.

        Returns:
            The sum over feature levels of the mean squared feature
            difference. ``spatial_average`` averages dims 1-3 with
            ``keepdim=True``, so for 4-D features this is one value per
            sample, shaped (B, 1, 1, 1).
        """
        feats_in = self._features(input)
        feats_tgt = self._features(target)
        per_level = [
            spatial_average((f_in - f_tgt) ** 2, keepdim=True)
            for f_in, f_tgt in zip(feats_in, feats_tgt)
        ]
        # Start the sum from the first level (not 0) so the additions happen
        # in the same left-to-right order as an explicit accumulation.
        return sum(per_level[1:], per_level[0])


class ScalingLayer(nn.Module):
    """Per-channel affine normalization applied before the feature extractor.

    The shift/scale constants are carried over from the upstream LPIPS code
    named in the module docstring. They presumably assume 3-channel images in
    [-1, 1] — confirm against the caller.
    """

    def __init__(self):
        super().__init__()
        # Stored as (1, 3, 1, 1) buffers so they broadcast over NCHW input and
        # follow the module across devices without being trainable.
        shift = torch.tensor([-.030, -.088, -.188]).reshape(1, -1, 1, 1)
        scale = torch.tensor([.458, .448, .450]).reshape(1, -1, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        """Optionally downsize ``inp``, then return ``(inp - shift) / scale``.

        Only the height (dim 2) is checked. When it exceeds 256, the input is
        resized to exactly 256x256 using ``F.interpolate``'s default
        nearest-neighbour mode, which stretches non-square inputs.
        NOTE(review): an input with height <= 256 but width > 256 is not
        resized — confirm only square images reach this layer.
        """
        too_tall = inp.shape[2] > 256
        if too_tall:
            inp = nn.functional.interpolate(inp, size=(256, 256))
        centered = inp - self.shift
        return centered / self.scale



def normalize_tensor(x, eps=1e-10):
    """Scale ``x`` to unit L2 norm along dimension 1 (channels for NCHW input).

    ``eps`` is added to the norm before dividing, so an all-zero vector maps
    to zeros instead of NaN.
    """
    channel_norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()
    return x / (channel_norm + eps)


def spatial_average(x, keepdim=True):
    """Average ``x`` over dims 1, 2 and 3, leaving one value per sample.

    Despite its name, this also averages over dim 1 (channels for NCHW input),
    not just the spatial dims. Because of that, feature maps with different
    channel counts reduce to the same shape and can be summed together. The
    upstream LPIPS variant averaged only dims 2 and 3.
    """
    return torch.mean(x, dim=(1, 2, 3), keepdim=keepdim)

